In [ ]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset
import matplotlib.pyplot as plt
import numpy as np
from torchmetrics.image.fid import FrechetInceptionDistance
from torchmetrics.image.inception import InceptionScore
import glob
import matplotlib.image as mpimg
In [ ]:
# ---------------------------
# 1. Hyperparameters
# ---------------------------
EPOCHS = 550 # Total epochs to train
BATCH_SIZE = 128
IMAGE_SIZE = 32
CHANNELS_IMG = 3
LATENT_DIM = 100
EMBED_DIM = 50 # Dimension for label embedding
CRITIC_ITER = 5 # Critic iterations per generator iteration
LAMBDA_GP = 10 # Gradient penalty lambda
LEARNING_RATE = 2e-4
BETA1, BETA2 = 0.5, 0.999
CHECKPOINT_EVERY = 20 # Save, sample images, compute IS/FID every 50 epochs
AUTOMOBILE_CLASS_IDX = 1
In [ ]:
# ---------------------------
# 2. Data Loading (CIFAR-10)
# ---------------------------
transform = transforms.Compose([
transforms.Resize(IMAGE_SIZE),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
trainset = datasets.CIFAR10(
root='./data',
train=True,
download=True,
transform=transform
)
# Filter to get only automobile images
automobile_indices = [i for i, (_, label) in enumerate(trainset) if label == AUTOMOBILE_CLASS_IDX]
automobile_dataset = Subset(trainset, automobile_indices)
# Create dataloader with only automobile images
trainloader = DataLoader(
automobile_dataset,
batch_size=BATCH_SIZE,
shuffle=True,
num_workers=2
)
In [ ]:
device = torch.device("mps" if torch.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
Using device: mps
In [ ]:
# ---------------------------
# 3. Models: Generator & Discriminator (Critic)
# ---------------------------
class Generator(nn.Module):
def __init__(self, latent_dim, embed_dim, num_classes=10):
super(Generator, self).__init__()
self.latent_dim = latent_dim
self.num_classes = num_classes
# Label embedding
self.label_emb = nn.Embedding(num_classes, embed_dim)
# Project + reshape to 4x4
self.fc = nn.Sequential(
nn.Linear(latent_dim + embed_dim, 4*4*512),
nn.BatchNorm1d(4*4*512),
nn.ReLU(True)
)
# Upsampling (transposed conv) to produce 32x32 images
self.deconv = nn.Sequential(
# 4x4 -> 8x8
nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1),
nn.BatchNorm2d(256),
nn.ReLU(True),
# 8x8 -> 16x16
nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
nn.BatchNorm2d(128),
nn.ReLU(True),
# 16x16 -> 32x32
nn.ConvTranspose2d(128, CHANNELS_IMG, kernel_size=4, stride=2, padding=1),
nn.Tanh()
)
def forward(self, z, labels):
# z: (batch_size, latent_dim)
# labels: (batch_size)
label_emb = self.label_emb(labels) # (batch_size, embed_dim)
concat = torch.cat([z, label_emb], dim=1) # (batch_size, latent_dim+embed_dim)
x = self.fc(concat) # (batch_size, 4*4*512)
x = x.view(-1, 512, 4, 4) # (batch_size, 512, 4, 4)
out = self.deconv(x) # (batch_size, 3, 32, 32)
return out
In [ ]:
class Discriminator(nn.Module):
def __init__(self, embed_dim, num_classes=10):
super(Discriminator, self).__init__()
self.num_classes = num_classes
# Label embedding
self.label_emb = nn.Embedding(num_classes, embed_dim)
self.conv = nn.Sequential(
nn.Conv2d(CHANNELS_IMG, 64, kernel_size=4, stride=2, padding=1),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1),
nn.LeakyReLU(0.2, inplace=True)
)
# For 32x32 images, final conv output is (512,2,2) => 512*2*2=2048
self.fc = nn.Sequential(
nn.Linear(512*2*2 + embed_dim, 1)
)
def forward(self, x, labels):
batch_size = x.shape[0]
features = self.conv(x) # (batch_size, 512, 2, 2)
features = features.view(batch_size, -1) # (batch_size, 2048)
label_emb = self.label_emb(labels) # (batch_size, embed_dim)
concat = torch.cat([features, label_emb], dim=1) # (batch_size, 2048 + embed_dim)
validity = self.fc(concat) # (batch_size, 1)
return validity
In [ ]:
# ---------------------------
# 4. Gradient Penalty Function
# ---------------------------
def compute_gradient_penalty(critic, real_imgs, fake_imgs, labels, device="cpu"):
alpha = torch.rand(real_imgs.size(0), 1, 1, 1, device=device)
alpha = alpha.expand_as(real_imgs)
interpolates = alpha * real_imgs + (1 - alpha) * fake_imgs
interpolates = interpolates.requires_grad_(True)
critic_interpolates = critic(interpolates, labels)
grads = torch.autograd.grad(
outputs=critic_interpolates,
inputs=interpolates,
grad_outputs=torch.ones_like(critic_interpolates),
create_graph=True,
retain_graph=True,
only_inputs=True
)[0]
grads = grads.view(grads.size(0), -1)
grad_norm = grads.norm(2, dim=1)
gradient_penalty = ((grad_norm - 1) ** 2).mean()
return gradient_penalty
In [ ]:
# ---------------------------
# 5. Initialize Models & Optimizers
# ---------------------------
gen = Generator(latent_dim=LATENT_DIM, embed_dim=EMBED_DIM).to(device)
crit = Discriminator(embed_dim=EMBED_DIM).to(device)
opt_gen = optim.Adam(gen.parameters(), lr=LEARNING_RATE, betas=(BETA1, BETA2))
opt_crit = optim.Adam(crit.parameters(), lr=LEARNING_RATE, betas=(BETA1, BETA2))
start_epoch = 1
checkpoint_path = "adl_part_1.pt"
In [ ]:
# ---------------------------
# 6. Check for Existing Checkpoint
# ---------------------------
if os.path.exists(checkpoint_path):
print("Checkpoint found. Loading...")
checkpoint = torch.load(checkpoint_path, map_location=device)
gen.load_state_dict(checkpoint["gen_state_dict"])
crit.load_state_dict(checkpoint["crit_state_dict"])
opt_gen.load_state_dict(checkpoint["opt_gen_state_dict"])
opt_crit.load_state_dict(checkpoint["opt_crit_state_dict"])
start_epoch = checkpoint["epoch"] + 1
print(f"Resuming from epoch {start_epoch}.")
Checkpoint found. Loading... Resuming from epoch 2.
In [ ]:
# ---------------------------
# 7. Utility: Generate & Display Images for "Automobile"
# ---------------------------
def generate_and_show_samples(epoch, num_samples=10):
gen.eval()
with torch.no_grad():
z = torch.randn(num_samples, LATENT_DIM).to(device)
labels = torch.full((num_samples,), AUTOMOBILE_CLASS_IDX, dtype=torch.long, device=device)
fake_images = gen(z, labels).cpu()
fake_images = (fake_images + 1) / 2.0 # Scale from [-1,1] to [0,1]
fig, axes = plt.subplots(1, num_samples, figsize=(num_samples*2.2, 2.2))
for i in range(num_samples):
img = fake_images[i].permute(1, 2, 0).numpy()
axes[i].imshow(img)
axes[i].axis('off')
plt.suptitle(f"Epoch {epoch}: Automobile Class Samples", fontsize=14)
plt.savefig(f'task1/automobile_gan_losses_{epoch}.png')
plt.show()
gen.train()
In [ ]:
# ---------------------------
# 8. Compute IS & FID with TorchMetrics
# ---------------------------
def compute_is_fid(generator, loader, n_samples=2000):
is_metric = InceptionScore().to("cpu")
fid_metric = FrechetInceptionDistance().to("cpu")
generator.eval()
real_count = 0
for real_imgs, _ in loader:
real_imgs = real_imgs.to(device)
real_imgs_uint8 = (((real_imgs * 0.5) + 0.5) * 255).to(torch.uint8).cpu()
fid_metric.update(real_imgs_uint8, real=True)
real_count += real_imgs.size(0)
if real_count >= n_samples:
break
fake_count = 0
while fake_count < n_samples:
z = torch.randn(BATCH_SIZE, LATENT_DIM, device=device)
labels = torch.randint(0, 10, (BATCH_SIZE,), dtype=torch.long, device=device)
with torch.no_grad():
fake_out = generator(z, labels)
fake_out_uint8 = (((fake_out * 0.5) + 0.5) * 255).to(torch.uint8).cpu()
is_metric.update(fake_out_uint8)
fid_metric.update(fake_out_uint8, real=False)
fake_count += BATCH_SIZE
inception_score = is_metric.compute() # (mean, std)
fid_score = fid_metric.compute()
generator.train()
return inception_score[0].item(), fid_score.item()
print(f"Starting training from epoch {start_epoch} to {EPOCHS}...")
Starting training from epoch 2 to 550...
In [ ]:
# ---------------------------
# 9. Training Loop
# ---------------------------
for epoch in range(start_epoch, EPOCHS + 1):
for batch_idx, (real_imgs, labels) in enumerate(trainloader):
real_imgs = real_imgs.to(device)
labels = labels.to(device)
# -----------------------
# Train Critic: K=5 iterations
# -----------------------
for _ in range(CRITIC_ITER):
z = torch.randn(real_imgs.size(0), LATENT_DIM).to(device)
gen_labels = torch.randint(0, 10, (real_imgs.size(0),), dtype=torch.long, device=device)
fake_imgs = gen(z, gen_labels)
real_validity = crit(real_imgs, labels)
fake_validity = crit(fake_imgs.detach(), gen_labels)
gp = compute_gradient_penalty(crit, real_imgs, fake_imgs.detach(), labels, device)
loss_crit = -torch.mean(real_validity) + torch.mean(fake_validity) + LAMBDA_GP * gp
opt_crit.zero_grad()
loss_crit.backward()
opt_crit.step()
# -----------------------
# Train Generator: 1 iteration
# -----------------------
z = torch.randn(real_imgs.size(0), LATENT_DIM).to(device)
gen_labels = torch.randint(0, 10, (real_imgs.size(0),), dtype=torch.long, device=device)
generated_imgs = gen(z, gen_labels)
fake_validity = crit(generated_imgs, gen_labels)
loss_gen = -torch.mean(fake_validity)
opt_gen.zero_grad()
loss_gen.backward()
opt_gen.step()
print(f"[Epoch {epoch}/{EPOCHS}] Loss Crit: {loss_crit.item():.4f} Loss Gen: {loss_gen.item():.4f}")
# -----------------------
# Save checkpoint, visualize samples, and compute IS/FID every CHECKPOINT_EVERY epochs
# -----------------------
if epoch % CHECKPOINT_EVERY == 0:
checkpoint_data = {
"epoch": epoch,
"gen_state_dict": gen.state_dict(),
"crit_state_dict": crit.state_dict(),
"opt_gen_state_dict": opt_gen.state_dict(),
"opt_crit_state_dict": opt_crit.state_dict()
}
torch.save(checkpoint_data, checkpoint_path)
print(f"[epoch={epoch}]Checkpoint saved: {checkpoint_path}")
generate_and_show_samples(epoch, num_samples=10)
is_val, fid_val = compute_is_fid(gen, trainloader, n_samples=2000)
print(f"==> Epoch {epoch}: Inception Score = {is_val:.4f}, FID = {fid_val:.4f}")
print("Training complete!")
==> Epoch 400: Inception Score = 3.7266, FID = 114.5365 [Epoch 401/550] Loss Crit: -24808.4199 Loss Gen: 13574.6162 [Epoch 402/550] Loss Crit: -28544.2969 Loss Gen: 13639.5449 [Epoch 403/550] Loss Crit: -28645.6289 Loss Gen: 13723.1406 [Epoch 404/550] Loss Crit: -25218.1562 Loss Gen: 13741.7598 [Epoch 405/550] Loss Crit: -28948.5801 Loss Gen: 13839.0479 [Epoch 406/550] Loss Crit: -29050.4785 Loss Gen: 13941.4092 [Epoch 407/550] Loss Crit: -29194.2051 Loss Gen: 10296.4570 [Epoch 408/550] Loss Crit: -18304.2715 Loss Gen: 14030.1875 [Epoch 409/550] Loss Crit: -25827.9648 Loss Gen: 14085.8076 [Epoch 410/550] Loss Crit: -29603.1445 Loss Gen: 14140.7402 [Epoch 411/550] Loss Crit: -29727.2871 Loss Gen: 14233.7188 [Epoch 412/550] Loss Crit: -22439.9258 Loss Gen: 14299.4873 [Epoch 413/550] Loss Crit: -26277.9141 Loss Gen: 14354.5234 [Epoch 414/550] Loss Crit: -22621.7656 Loss Gen: 14435.1035 [Epoch 415/550] Loss Crit: -30260.2422 Loss Gen: 14505.7754 [Epoch 416/550] Loss Crit: -26660.2266 Loss Gen: 10805.9102 [Epoch 417/550] Loss Crit: -30535.2637 Loss Gen: 10804.8086 [Epoch 418/550] Loss Crit: -26888.9824 Loss Gen: 10836.1748 [Epoch 419/550] Loss Crit: -30821.7051 Loss Gen: 10971.6523 [Epoch 420/550] Loss Crit: -31010.6426 Loss Gen: 14857.1289 [epoch=420]Checkpoint saved: adl_part_1.pt
==> Epoch 420: Inception Score = 3.6283, FID = 112.3696 [Epoch 421/550] Loss Crit: -27253.3301 Loss Gen: 11029.1328 [Epoch 422/550] Loss Crit: -23461.2402 Loss Gen: 15000.4277 [Epoch 423/550] Loss Crit: -31436.6230 Loss Gen: 15014.4023 [Epoch 424/550] Loss Crit: -31590.4980 Loss Gen: 11153.7773 [Epoch 425/550] Loss Crit: -31707.7012 Loss Gen: 15177.0449 [Epoch 426/550] Loss Crit: -31838.0449 Loss Gen: 11263.9355 [Epoch 427/550] Loss Crit: -19998.7988 Loss Gen: 11342.3867 [Epoch 428/550] Loss Crit: -32146.8809 Loss Gen: 11364.6348 [Epoch 429/550] Loss Crit: -32277.8711 Loss Gen: 15510.6211 [Epoch 430/550] Loss Crit: -32349.4453 Loss Gen: 11453.3887 [Epoch 431/550] Loss Crit: -32579.1484 Loss Gen: 7473.4731 [Epoch 432/550] Loss Crit: -20437.2324 Loss Gen: 11572.3457 [Epoch 433/550] Loss Crit: -32847.5859 Loss Gen: 15733.1777 [Epoch 434/550] Loss Crit: -32973.5039 Loss Gen: 15831.9258 [Epoch 435/550] Loss Crit: -16577.8633 Loss Gen: 15892.6602 [Epoch 436/550] Loss Crit: -29160.5879 Loss Gen: 15941.3564 [Epoch 437/550] Loss Crit: -33458.1016 Loss Gen: 16042.4814 [Epoch 438/550] Loss Crit: -29349.4941 Loss Gen: 7706.1177 [Epoch 439/550] Loss Crit: -25297.2637 Loss Gen: 11951.4561 [Epoch 440/550] Loss Crit: -29666.3516 Loss Gen: 7740.7476 [epoch=440]Checkpoint saved: adl_part_1.pt
==> Epoch 440: Inception Score = 3.7752, FID = 119.9983 [Epoch 441/550] Loss Crit: -25510.5234 Loss Gen: 12061.8154 [Epoch 442/550] Loss Crit: -34159.0508 Loss Gen: 7863.9141 [Epoch 443/550] Loss Crit: -34291.4531 Loss Gen: 12208.3359 [Epoch 444/550] Loss Crit: -30166.7266 Loss Gen: 16515.2344 [Epoch 445/550] Loss Crit: -34637.5898 Loss Gen: 16574.9746 [Epoch 446/550] Loss Crit: -34745.4727 Loss Gen: 16657.0840 [Epoch 447/550] Loss Crit: -34864.0430 Loss Gen: 16780.8789 [Epoch 448/550] Loss Crit: -30689.7051 Loss Gen: 12464.6318 [Epoch 449/550] Loss Crit: -35224.4219 Loss Gen: 12458.5625 [Epoch 450/550] Loss Crit: -35354.7539 Loss Gen: 16983.4902 [Epoch 451/550] Loss Crit: -35556.6094 Loss Gen: 12589.6895 [Epoch 452/550] Loss Crit: -31194.4980 Loss Gen: 17097.1445 [Epoch 453/550] Loss Crit: -35806.5820 Loss Gen: 8227.9834 [Epoch 454/550] Loss Crit: -35961.4805 Loss Gen: 17277.6836 [Epoch 455/550] Loss Crit: -27090.8008 Loss Gen: 12781.1982 [Epoch 456/550] Loss Crit: -36250.1016 Loss Gen: 12883.2256 [Epoch 457/550] Loss Crit: -36407.1836 Loss Gen: 8348.8711 [Epoch 458/550] Loss Crit: -32039.4609 Loss Gen: 12957.5771 [Epoch 459/550] Loss Crit: -36742.1836 Loss Gen: 13044.7842 [Epoch 460/550] Loss Crit: -36869.3594 Loss Gen: 13033.3271 [epoch=460]Checkpoint saved: adl_part_1.pt
==> Epoch 460: Inception Score = 3.6439, FID = 113.2828 [Epoch 461/550] Loss Crit: -37030.6094 Loss Gen: 17749.9883 [Epoch 462/550] Loss Crit: -32535.0762 Loss Gen: 17839.3965 [Epoch 463/550] Loss Crit: -32678.5527 Loss Gen: 17934.4102 [Epoch 464/550] Loss Crit: -28102.0332 Loss Gen: 13275.2773 [Epoch 465/550] Loss Crit: -37642.6133 Loss Gen: 18093.8887 [Epoch 466/550] Loss Crit: -37811.0430 Loss Gen: 13378.0547 [Epoch 467/550] Loss Crit: -28494.9082 Loss Gen: 18207.5371 [Epoch 468/550] Loss Crit: -33369.0352 Loss Gen: 8744.5342 [Epoch 469/550] Loss Crit: -38273.7969 Loss Gen: 13573.7891 [Epoch 470/550] Loss Crit: -28847.4609 Loss Gen: 4012.7524 [Epoch 471/550] Loss Crit: -38580.0234 Loss Gen: 13661.3613 [Epoch 472/550] Loss Crit: -29104.6641 Loss Gen: 18593.0156 [Epoch 473/550] Loss Crit: -38947.0820 Loss Gen: 13805.4697 [Epoch 474/550] Loss Crit: -39055.2695 Loss Gen: 18759.1836 [Epoch 475/550] Loss Crit: -24526.0664 Loss Gen: 9031.8398 [Epoch 476/550] Loss Crit: -39387.1367 Loss Gen: 13975.5869 [Epoch 477/550] Loss Crit: -39527.3086 Loss Gen: 14031.3770 [Epoch 478/550] Loss Crit: -39715.7930 Loss Gen: 19056.3711 [Epoch 479/550] Loss Crit: -34876.2070 Loss Gen: 19143.6309 [Epoch 480/550] Loss Crit: -35002.4297 Loss Gen: 19181.0078 [epoch=480]Checkpoint saved: adl_part_1.pt
==> Epoch 480: Inception Score = 3.4911, FID = 122.3214 [Epoch 481/550] Loss Crit: -35066.7461 Loss Gen: 14287.7051 [Epoch 482/550] Loss Crit: -30273.4023 Loss Gen: 19318.3672 [Epoch 483/550] Loss Crit: -35434.3750 Loss Gen: 4263.6377 [Epoch 484/550] Loss Crit: -40652.5547 Loss Gen: 14434.2314 [Epoch 485/550] Loss Crit: -35743.4688 Loss Gen: 14482.2559 [Epoch 486/550] Loss Crit: -35838.4062 Loss Gen: 4318.8652 [Epoch 487/550] Loss Crit: -35977.8711 Loss Gen: 14628.6016 [Epoch 488/550] Loss Crit: -36116.2500 Loss Gen: 19758.1406 [Epoch 489/550] Loss Crit: -36262.7539 Loss Gen: 19839.0000 [Epoch 490/550] Loss Crit: -36437.5156 Loss Gen: 19982.1641 [Epoch 491/550] Loss Crit: -41747.2422 Loss Gen: 20077.3926 [Epoch 492/550] Loss Crit: -41986.9531 Loss Gen: 14908.4902 [Epoch 493/550] Loss Crit: -42096.1797 Loss Gen: 14992.3818 [Epoch 494/550] Loss Crit: -42303.6289 Loss Gen: 20296.9863 [Epoch 495/550] Loss Crit: -42423.7500 Loss Gen: 15094.2832 [Epoch 496/550] Loss Crit: -37282.2461 Loss Gen: 20498.1875 [Epoch 497/550] Loss Crit: -37384.9141 Loss Gen: 20547.9219 [Epoch 498/550] Loss Crit: -37610.2031 Loss Gen: 15267.0469 [Epoch 499/550] Loss Crit: -37680.9062 Loss Gen: 15322.2686 [Epoch 500/550] Loss Crit: -37861.6172 Loss Gen: 20798.2656 [epoch=500]Checkpoint saved: adl_part_1.pt
==> Epoch 500: Inception Score = 3.9482, FID = 123.5886 [Epoch 501/550] Loss Crit: -32555.6348 Loss Gen: 20809.9355 [Epoch 502/550] Loss Crit: -43564.2773 Loss Gen: 20938.0742 [Epoch 503/550] Loss Crit: -38290.8672 Loss Gen: 21029.7031 [Epoch 504/550] Loss Crit: -43967.1680 Loss Gen: 21131.1465 [Epoch 505/550] Loss Crit: -44105.5312 Loss Gen: 15688.3613 [Epoch 506/550] Loss Crit: -44284.9883 Loss Gen: 15746.9805 [Epoch 507/550] Loss Crit: -44495.1367 Loss Gen: 21304.9082 [Epoch 508/550] Loss Crit: -44624.3281 Loss Gen: 21417.5293 [Epoch 509/550] Loss Crit: -44804.8867 Loss Gen: 10336.5684 [Epoch 510/550] Loss Crit: -39340.1523 Loss Gen: 15948.5625 [Epoch 511/550] Loss Crit: -45087.6836 Loss Gen: 10435.0703 [Epoch 512/550] Loss Crit: -33946.4219 Loss Gen: 21760.8633 [Epoch 513/550] Loss Crit: -45449.6641 Loss Gen: 21833.7852 [Epoch 514/550] Loss Crit: -39939.1875 Loss Gen: 21903.2969 [Epoch 515/550] Loss Crit: -40047.6953 Loss Gen: 21999.4258 [Epoch 516/550] Loss Crit: -40217.8242 Loss Gen: 16379.8184 [Epoch 517/550] Loss Crit: -34593.7578 Loss Gen: 22208.5820 [Epoch 518/550] Loss Crit: -34720.4961 Loss Gen: 16452.2461 [Epoch 519/550] Loss Crit: -46507.3828 Loss Gen: 22368.8574 [Epoch 520/550] Loss Crit: -46655.8594 Loss Gen: 10749.6729 [epoch=520]Checkpoint saved: adl_part_1.pt
==> Epoch 520: Inception Score = 3.7669, FID = 113.3025 [Epoch 521/550] Loss Crit: -40962.5391 Loss Gen: 4973.6323 [Epoch 522/550] Loss Crit: -41141.8828 Loss Gen: 16749.5000 [Epoch 523/550] Loss Crit: -41324.2422 Loss Gen: 10863.4189 [Epoch 524/550] Loss Crit: -41400.8789 Loss Gen: 22704.7988 [Epoch 525/550] Loss Crit: -35598.8750 Loss Gen: 10900.7471 [Epoch 526/550] Loss Crit: -47670.1758 Loss Gen: 16954.3691 [Epoch 527/550] Loss Crit: -41862.9727 Loss Gen: 16990.3125 [Epoch 528/550] Loss Crit: -48039.3789 Loss Gen: 17058.2070 [Epoch 529/550] Loss Crit: -48183.4492 Loss Gen: 17105.9453 [Epoch 530/550] Loss Crit: -36282.7188 Loss Gen: 11188.9043 [Epoch 531/550] Loss Crit: -48537.9023 Loss Gen: 17303.6270 [Epoch 532/550] Loss Crit: -48800.4453 Loss Gen: 23466.0508 [Epoch 533/550] Loss Crit: -48924.7227 Loss Gen: 23480.7422 [Epoch 534/550] Loss Crit: -42978.7734 Loss Gen: 17509.1660 [Epoch 535/550] Loss Crit: -49290.7344 Loss Gen: 23679.0039 [Epoch 536/550] Loss Crit: -43255.1055 Loss Gen: 17559.1348 [Epoch 537/550] Loss Crit: -43474.0078 Loss Gen: 23888.0293 [Epoch 538/550] Loss Crit: -43573.5664 Loss Gen: 23938.4062 [Epoch 539/550] Loss Crit: -50023.0859 Loss Gen: 24081.0508 [Epoch 540/550] Loss Crit: -43889.2422 Loss Gen: 24082.7461 [epoch=540]Checkpoint saved: adl_part_1.pt
==> Epoch 540: Inception Score = 3.7531, FID = 108.8194 [Epoch 541/550] Loss Crit: -31465.7891 Loss Gen: 24237.8613 [Epoch 542/550] Loss Crit: -50488.6562 Loss Gen: 5347.4287 [Epoch 543/550] Loss Crit: -50698.7773 Loss Gen: 18079.8125 [Epoch 544/550] Loss Crit: -50831.2930 Loss Gen: 24476.7188 [Epoch 545/550] Loss Crit: -51058.2930 Loss Gen: 24590.9941 [Epoch 546/550] Loss Crit: -51230.6055 Loss Gen: 11811.0273 [Epoch 547/550] Loss Crit: -32166.0312 Loss Gen: 24732.5977 [Epoch 548/550] Loss Crit: -51645.7656 Loss Gen: 24823.6035 [Epoch 549/550] Loss Crit: -51777.6797 Loss Gen: 24917.9492 [Epoch 550/550] Loss Crit: -51947.8008 Loss Gen: 11993.6816 Training complete!
In [ ]:
directory = r'task1'
# Define a custom sort key that extracts the epoch number
def extract_epoch(filename):
base = os.path.basename(filename)
try:
# Assuming filename format: automobile_samples_epoch_{epoch_number}.png
epoch_str = base.split('automobile_gan_losses_')[1].split('.')[0]
return int(epoch_str)
except (IndexError, ValueError):
return float('inf') # Place any files that don't match the pattern at the end
png_files = glob.glob(os.path.join(directory, '*.png'))
# Sort the list numerically by epoch number
png_files = sorted(png_files, key=extract_epoch)
# Check if any PNG files are found
if not png_files:
print("No PNG files found in the directory:", directory)
else:
n = len(png_files)
# Increase the figure size to accommodate full screen-like display
fig, axs = plt.subplots(n, 1, figsize=(22, 2.4 * n))
# If only one image, wrap axs into a list for consistency
if n == 1:
axs = [axs]
mng = plt.get_current_fig_manager()
try:
mng.window.state('zoomed')
except AttributeError:
try:
mng.window.showMaximized()
except Exception:
pass # If it fails, the figure will remain at the set figsize
for ax, file in zip(axs, png_files):
img = mpimg.imread(file)
ax.imshow(img, aspect='auto')
ax.axis('off')
ax.set_title(os.path.basename(file), fontsize=14)
plt.tight_layout()
plt.show()
In [ ]:
# Install necessary packages
!apt-get install texlive texlive-xetex texlive-latex-extra pandoc
!pip install pypandoc
# Mount Google Drive
from google.colab import drive
drive.mount("/content/drive", force_remount=True)
# Copy the notebook to the current directory
!cp 'drive/My Drive/Colab Notebooks/Assignment2_Group75_Task1.ipynb' ./
# Convert the notebook to PDF while keeping the code and output
!jupyter nbconvert --to pdf "Assignment2_Group75_Task1.ipynb"
# Download the generated PDF
from google.colab import files
files.download('Assignment2_Group75_Task1.pdf')